/**
 * 2017年5月21日
 */
package cn.edu.bjtu.datasource.lsp;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.regex.Pattern;

import org.datavec.api.records.reader.impl.regex.RegexLineRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.iterator.LabeledSentenceProvider;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import cn.edu.bjtu.core.SentenceDetect;



/**
 * Labeled sentence iterator used to train the CNN text classifier.
 * <p>
 * Each input line is parsed by a regular expression into a label field and a content field
 * (see {@link #TrainCNNSentenceProvider(int)} for the default layout). Content is cleaned of
 * literal {@code \r} / {@code \n} escape sequences and whitespace. Records whose content is
 * rejected by {@link #shouldKeep(String)} are skipped and counted in {@link #filterCount}.
 * <p>
 * {@link LabeledSentenceProvider} allows only a single label per sentence. When the label field
 * holds a comma-separated list, only its first entry is returned.
 * <p>
 * Design note: declaring
 * {@code TrainCNNSentenceProvider extends WechatLineSentenceRecordReader implements LabeledSentenceProvider}
 * does not work. The two {@code nextSentence} methods have different return types, so one
 * cannot overload the other.
 *
 * @author Alex
 */
public class TrainCNNSentenceProvider extends RegexLineRecordReader implements LabeledSentenceProvider {
	private static final long serialVersionUID = 61137403720357384L;

	/** Number of label classes. The labels are the strings "1" through "13". */
	private static final int NUM_LABEL_CLASSES = 13;

	/**
	 * Fixed value reported by {@link #totalNumSentences()}. It is a placeholder, not a count of
	 * the records in the input.
	 */
	private static final int ASSUMED_TOTAL_SENTENCES = 10000;

	/**
	 * Default line layout: optional '$', label, TAB, id/url (discarded), TAB, content, TAB,
	 * date (discarded). Group 1 is the label field and group 2 is the content.
	 */
	private static final String DEFAULT_LINE_REGEX = "\\$?(.*?)\t(?:.*)\\t(.*)\t(?:.*)";

	/**
	 * Matches literal two-character escape sequences {@code \r} / {@code \n} and any whitespace.
	 * A character class such as {@code [\\r\\n\s]} would instead delete every backslash and every
	 * letter 'r' and 'n' from the text.
	 */
	private static final Pattern CONTENT_NOISE = Pattern.compile("\\\\[rn]|\\s");

	/** Static so the logger is not part of this (serializable) reader's instance state. */
	protected static final Logger logger = LoggerFactory.getLogger(TrainCNNSentenceProvider.class);

	/** Number of records rejected by {@link #shouldKeep(String)} so far. */
	protected AtomicInteger filterCount = new AtomicInteger(0);

	/** Raw label field of the most recently read record (may be a comma-separated list). */
	protected String labels = null;

	/** Cleaned content of the most recently read record. */
	protected String content = null;

	/** All label classes, filled with "1" through "13" by the constructor. */
	protected List<String> _allLabels = new ArrayList<>(15);

	/**
	 * Creates a provider with a custom line regex.
	 *
	 * @param regex        regular expression whose group 1 is the label field and group 2 the content
	 * @param skipNumLines number of leading lines to skip
	 */
	protected TrainCNNSentenceProvider(String regex, int skipNumLines) {
		super(regex, skipNumLines);
		for (int i = 1; i <= NUM_LABEL_CLASSES; i++) {
			_allLabels.add(String.valueOf(i));
		}
	}

	/**
	 * Creates a provider for the default tab-separated layout.
	 *
	 * @param skip number of leading lines to skip
	 */
	public TrainCNNSentenceProvider(int skip) {
		this(DEFAULT_LINE_REGEX, skip);
	}

	/** Creates a provider for the default layout without skipping any lines. */
	public TrainCNNSentenceProvider() {
		this(0);
	}

	/**
	 * Decides whether a cleaned sentence should be used for training, according to
	 * {@code SentenceDetect.isFullfield}. Rejections are counted in {@link #filterCount}.
	 *
	 * @param strName cleaned sentence content
	 * @return {@code true} if the sentence is accepted
	 */
	protected boolean shouldKeep(String strName) {
		boolean keep = SentenceDetect.get().isFullfield(strName);
		if (!keep) {
			filterCount.incrementAndGet();
		}
		return keep;
	}

	/**
	 * Returns the next accepted sentence and its (first) label. Records rejected by
	 * {@link #shouldKeep(String)} are skipped. Every record read is cleaned the same way.
	 *
	 * @return Pair: sentence/document text and label
	 * @throws IllegalStateException if no remaining record is accepted
	 */
	@Override
	public Pair<String, String> nextSentence() {
		boolean keep;
		do {
			List<Writable> record = next();
			labels = record.get(0).toString();
			content = cleanContent(record.get(1).toString());
			// Evaluate the filter exactly once per record so filterCount is not inflated.
			keep = shouldKeep(content);
		} while (!keep && hasNext());

		if (!keep) {
			logger.warn("At split {} at lineIndex {}, content is {} not satisfied", this.splitIndex, this.lineIndex, content);
			throw new IllegalStateException("Illegal DataSet");
		}
		return Pair.makePair(content, firstLabel(labels));
	}

	/** Removes literal {@code \r} / {@code \n} escape sequences and all whitespace. */
	private static String cleanContent(String raw) {
		return CONTENT_NOISE.matcher(raw).replaceAll("");
	}

	/** Returns the text before the first comma of {@code labelField}, or the whole field if there is none. */
	private static String firstLabel(String labelField) {
		int comma = labelField.indexOf(',');
		return comma == -1 ? labelField : labelField.substring(0, comma);
	}

	/**
	 * Returns a fixed placeholder value rather than the real number of sentences in the input.
	 *
	 * @return {@value #ASSUMED_TOTAL_SENTENCES}
	 */
	@Override
	public int totalNumSentences() {
		logger.debug("total num sentence invoked!");
		return ASSUMED_TOTAL_SENTENCES;
	}

	/** @return an unmodifiable view of all label classes */
	@Override
	public List<String> allLabels() {
		return Collections.unmodifiableList(_allLabels);
	}

	/** @return number of label classes */
	@Override
	public int numLabelClasses() {
		return _allLabels.size();
	}
}
